import torch
from torch.optim.lr_scheduler import _LRScheduler, StepLR

class WarmupStepLR(_LRScheduler):
    """Step-decay learning-rate schedule preceded by a linear warmup.

    While ``last_epoch < warmup_epochs`` each learning rate is
    ``base_lr * (last_epoch + 1) / warmup_epochs``. From then on it is
    ``base_lr * gamma ** (last_epoch // step_size)``.

    Note that the decay exponent is computed from the absolute epoch index,
    so decay steps that fall inside the warmup window are already applied
    when warmup ends.

    Args:
        optimizer: Wrapped optimizer.
        warmup_epochs: Number of epochs over which the rate ramps up linearly.
        step_size: Period, in epochs, of the multiplicative decay.
        gamma: Multiplicative decay factor.
        last_epoch: Index of the last epoch, forwarded to the base scheduler.
    """

    def __init__(self, optimizer, warmup_epochs, step_size, gamma, last_epoch=-1):
        # These attributes must exist before the base constructor runs,
        # because it performs an initial step that calls get_lr().
        self.warmup_epochs = warmup_epochs
        self.step_size = step_size
        self.gamma = gamma
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the list of learning rates (one per base LR) for the current epoch."""
        current = self.last_epoch

        if current >= self.warmup_epochs:
            # Past warmup: StepLR-style decay based on the absolute epoch.
            decay = self.gamma ** (current // self.step_size)
            return [lr * decay for lr in self.base_lrs]

        # Warmup phase: scale linearly up to the base learning rate.
        progress = current + 1
        return [lr * progress / self.warmup_epochs for lr in self.base_lrs]


def get_lr_scheduler(args, optimizer):
    """Build the learning-rate scheduler selected by ``args.lr_scheduler``.

    Args:
        args: Namespace-like object. ``args.lr_scheduler`` selects the
            scheduler (or is ``None`` for no scheduler); the remaining
            attributes read depend on the selected branch (e.g. ``step_size``
            and ``gamma`` for ``StepLR``).
        optimizer: Optimizer whose learning rate is to be scheduled.

    Returns:
        The constructed scheduler, or ``None`` when ``args.lr_scheduler`` is
        ``None``.

    Raises:
        ValueError: If ``args.lr_scheduler`` is not a recognised name.
    """
    name = args.lr_scheduler
    if name is None:
        return None

    sched = torch.optim.lr_scheduler

    # Keyword arguments are used throughout so that values cannot silently
    # bind to the wrong parameter of the torch scheduler signatures.
    if name == 'WarmupStepLR':
        lr_scheduler = WarmupStepLR(optimizer, warmup_epochs=args.warmup_epochs,
                                    step_size=args.step_size, gamma=args.gamma)
    elif name == 'StepLR':
        lr_scheduler = sched.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    elif name == 'MultiStepLR':
        lr_scheduler = sched.MultiStepLR(optimizer, milestones=args.milestone, gamma=args.gamma)
    elif name == 'ExponentialLR':
        lr_scheduler = sched.ExponentialLR(optimizer, gamma=args.gamma)
    elif name == 'LinearLR':
        lr_scheduler = sched.LinearLR(optimizer, start_factor=1, end_factor=args.factor,
                                      total_iters=args.total_iters)
    elif name == 'CyclicLR':
        # cycle_momentum must be passed by keyword: positionally it would
        # land in the `mode` parameter and be rejected as an invalid mode.
        lr_scheduler = sched.CyclicLR(optimizer, base_lr=args.min_lr, max_lr=args.lr,
                                      step_size_up=args.step_size_up,
                                      step_size_down=args.step_size_down,
                                      cycle_momentum=args.cycle_momentum)
    elif name == 'OneCycleLR':
        # OneCycleLR needs a total step count, given here as
        # epochs * steps_per_epoch. steps_per_epoch defaults to 1 (scheduler
        # stepped once per epoch) when args does not provide it.
        lr_scheduler = sched.OneCycleLR(optimizer, max_lr=args.lr,
                                        epochs=args.epochs,
                                        steps_per_epoch=getattr(args, 'steps_per_epoch', 1),
                                        pct_start=args.pct_start,
                                        div_factor=args.div_factor,
                                        final_div_factor=args.final_div_factor)
    elif name == 'CosineAnnealingLR':
        lr_scheduler = sched.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.min_lr)
    elif name == 'SequentialLR':
        # Fixed recipe: exponential decay, then linear decay from epoch 50.
        lr_scheduler = sched.SequentialLR(
            optimizer,
            schedulers=[
                sched.ExponentialLR(optimizer, gamma=0.9),
                sched.LinearLR(optimizer, start_factor=1, end_factor=0.1, total_iters=80),
            ],
            milestones=[50],
        )
    elif name == 'ChainedScheduler':
        # Fixed recipe: linear and exponential decay applied together.
        lr_scheduler = sched.ChainedScheduler([
            sched.LinearLR(optimizer, start_factor=1, end_factor=0.5, total_iters=10),
            sched.ExponentialLR(optimizer, gamma=0.95),
        ])
    elif name == 'ConstantLR':
        lr_scheduler = sched.ConstantLR(optimizer, factor=args.factor, total_iters=args.total_iters)
    else:
        raise ValueError('unknown learning rate scheduler: {}'.format(args.lr_scheduler))
    return lr_scheduler